{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute main.py so the names it defines become available in this notebook\n",
    "# (DATA_DIR, S3, exists, ... are presumably defined or imported there - confirm).\n",
    "%run main.py\n",
    "# Load the autoreload extension and switch it to mode 2.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Create the working directories; -p keeps this from failing when they already exist.\n",
    "!mkdir -p {DATA_DIR} {NAVEC_DIR} {MODEL_DIR}\n",
    "# Client object whose download() is called by the next cell.\n",
    "s3 = S3()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch each artifact only when its own local copy is missing, so a partially\n",
    "# populated data directory is repaired on the next run.\n",
    "if not exists(NERUS):\n",
    "    s3.download(S3_NERUS, NERUS)\n",
    "\n",
    "if not exists(NAVEC):\n",
    "    !wget {NAVEC_URL} -O {NAVEC}\n",
    "\n",
    "# The tags vocab is guarded by its own path rather than by NAVEC: previously an\n",
    "# existing NAVEC file skipped this download even when TAGS_VOCAB was absent,\n",
    "# which made the later Vocab.load(TAGS_VOCAB) fail.\n",
    "if not exists(TAGS_VOCAB):\n",
    "    s3.download(S3_TAGS_VOCAB, TAGS_VOCAB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Navec object from the local file at NAVEC; later cells read its\n",
    "# vocab words and pass it to NavecEmbedding.\n",
    "navec = Navec.load(NAVEC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Word vocab is built from the words of the loaded Navec vocab.\n",
    "words_vocab = Vocab(navec.vocab.words)\n",
    "# Shape vocab: the PAD item first, followed by every entry of SHAPES.\n",
    "shapes_vocab = Vocab([PAD, *SHAPES])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocab construction kept for reference and left commented out: it iterates the\n",
    "# markups parsed from NERUS, collects the distinct token tags, prepends PAD to the\n",
    "# sorted tags and dumps the resulting vocab to TAGS_VOCAB. The saved vocab is\n",
    "# loaded on the last line of this cell instead.\n",
    "\n",
    "# lines = load_gz_lines(NERUS)\n",
    "# lines = log_progress(lines, total=NERUS_TOTAL)\n",
    "# items = parse_jl(lines)\n",
    "# markups = (MorphMarkup.from_json(_) for _ in items)\n",
    "\n",
    "# tags = set()\n",
    "# for markup in markups:\n",
    "#     for token in markup.tokens:\n",
    "#         tags.add(token.tag)\n",
    "\n",
    "# tags = [PAD] + sorted(tags)\n",
    "# tags_vocab = Vocab(tags)\n",
    "# tags_vocab.dump(TAGS_VOCAB)\n",
    "\n",
    "tags_vocab = Vocab.load(TAGS_VOCAB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the RNGs with SEED before the model is built and the data is encoded below.\n",
    "# NOTE(review): seed() is presumably random.seed brought in by main.py - confirm.\n",
    "torch.manual_seed(SEED)\n",
    "seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss callable handed to process_batch in the training loop.\n",
    "criterion = flatten_cross_entropy\n",
    "\n",
    "# Input layer: TagEmbedding built from the Navec word embedding and a shape embedding.\n",
    "word = NavecEmbedding(navec)\n",
    "shape = Embedding(vocab_size=len(shapes_vocab), dim=SHAPE_DIM, pad_id=shapes_vocab.pad_id)\n",
    "emb = TagEmbedding(word, shape)\n",
    "\n",
    "# Encoder sized from the embedding dim, then a head sized to the tags vocab.\n",
    "encoder = TagEncoder(input_dim=emb.dim, layer_dims=LAYER_DIMS, kernel_size=KERNEL_SIZE)\n",
    "morph = MorphHead(encoder.dim, len(tags_vocab))\n",
    "\n",
    "# Assemble the full model and move it to DEVICE in one expression.\n",
    "model = Morph(emb, encoder, morph).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lazily read NERUS: gz lines wrapped by log_progress, parsed into items,\n",
    "# each item turned into a MorphMarkup.\n",
    "lines = log_progress(load_gz_lines(NERUS), total=NERUS_TOTAL)\n",
    "items = parse_jl(lines)\n",
    "markups = (MorphMarkup.from_json(item) for item in items)\n",
    "\n",
    "# Encode the markups into batches and move every batch to DEVICE.\n",
    "encode = TagTrainEncoder(\n",
    "    words_vocab,\n",
    "    shapes_vocab,\n",
    "    tags_vocab,\n",
    "    seq_len=256,\n",
    "    batch_size=64,\n",
    "    shuffle_size=1000,\n",
    ")\n",
    "batches = [batch.to(DEVICE) for batch in encode(markups)]\n",
    "\n",
    "# The first `size` batches form the TEST split; everything after them is TRAIN.\n",
    "size = 10\n",
    "test_batches, train_batches = batches[:size], batches[size:]\n",
    "batches = {TEST: test_batches, TRAIN: train_batches}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single MultiBoard wrapping a TensorBoard(BOARD_NAME, RUNS_DIR) and a LogBoard.\n",
    "board = MultiBoard([TensorBoard(BOARD_NAME, RUNS_DIR), LogBoard()])\n",
    "\n",
    "# One board section per split, keyed by TRAIN / TEST.\n",
    "boards = {\n",
    "    split: board.section(name)\n",
    "    for split, name in ((TRAIN, TRAIN_BOARD), (TEST, TEST_BOARD))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): optim is presumably torch.optim brought in by main.py - confirm.\n",
    "# Adam over all model parameters with learning rate LR; the scheduler is built\n",
    "# with LR_GAMMA and is stepped once per epoch in the training loop.\n",
    "optimizer = optim.Adam(model.parameters(), lr=LR)\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, LR_GAMMA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One score meter per split; each is written to its board section and reset\n",
    "# at the end of its pass in every epoch.\n",
    "meters = {\n",
    "    TRAIN: MorphScoreMeter(),\n",
    "    TEST: MorphScoreMeter(),\n",
    "}\n",
    "\n",
    "for epoch in log_progress(range(EPOCHS)):\n",
    "    # Training pass: one backward pass and one optimizer step per TRAIN batch.\n",
    "    model.train()\n",
    "    for batch in log_progress(batches[TRAIN], leave=False):\n",
    "        optimizer.zero_grad()\n",
    "        # The name batch is rebound to the result of process_batch, which exposes .loss.\n",
    "        batch = process_batch(model, criterion, batch)\n",
    "        batch.loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "        score = score_morph_batch(batch)\n",
    "        meters[TRAIN].add(score)\n",
    "\n",
    "    meters[TRAIN].write(boards[TRAIN])\n",
    "    meters[TRAIN].reset()\n",
    "\n",
    "    # Evaluation pass over the TEST batches with gradient tracking disabled;\n",
    "    # no optimizer calls are made here.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch in log_progress(batches[TEST], leave=False, desc=TEST):\n",
    "            batch = process_batch(model, criterion, batch)\n",
    "            score = score_morph_batch(batch)\n",
    "            meters[TEST].add(score)\n",
    "\n",
    "        meters[TEST].write(boards[TEST])\n",
    "        meters[TEST].reset()\n",
    "    \n",
    "    # Once per epoch: step the learning-rate scheduler and the board.\n",
    "    scheduler.step()\n",
    "    board.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log output recorded from a previous run, kept for reference.\n",
    "# Columns: [timestamp], step, value, section/metric. The step is presumably the\n",
    "# epoch index, since board.step() is called once per epoch - confirm.\n",
    "# [2020-04-11 07:43:31]    0 0.1641 01_train/01_loss\n",
    "# [2020-04-11 07:43:31]    0 0.9529 01_train/02_acc\n",
    "# [2020-04-11 07:43:32]    0 0.1220 02_test/01_loss\n",
    "# [2020-04-11 07:43:32]    0 0.9626 02_test/02_acc\n",
    "# [2020-04-11 07:45:27]    1 0.1163 01_train/01_loss\n",
    "# [2020-04-11 07:45:27]    1 0.9643 01_train/02_acc\n",
    "# [2020-04-11 07:45:27]    1 0.1118 02_test/01_loss\n",
    "# [2020-04-11 07:45:27]    1 0.9657 02_test/02_acc\n",
    "# [2020-04-11 07:47:23]    2 0.1094 01_train/01_loss\n",
    "# [2020-04-11 07:47:23]    2 0.9663 01_train/02_acc\n",
    "# [2020-04-11 07:47:23]    2 0.1073 02_test/01_loss\n",
    "# [2020-04-11 07:47:23]    2 0.9670 02_test/02_acc\n",
    "# [2020-04-11 07:49:19]    3 0.1057 01_train/01_loss\n",
    "# [2020-04-11 07:49:19]    3 0.9673 01_train/02_acc\n",
    "# [2020-04-11 07:49:19]    3 0.1045 02_test/01_loss\n",
    "# [2020-04-11 07:49:19]    3 0.9680 02_test/02_acc\n",
    "# [2020-04-11 07:51:14]    4 0.1032 01_train/01_loss\n",
    "# [2020-04-11 07:51:14]    4 0.9681 01_train/02_acc\n",
    "\n",
    "# [2020-04-11 07:51:14]    4 0.1027 02_test/01_loss\n",
    "# [2020-04-11 07:51:14]    4 0.9683 02_test/02_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export step, left commented out so that running the notebook does not execute it:\n",
    "# dumps the tags vocab and the model's shape embedding, encoder and morph head to\n",
    "# local files, then uploads each file via s3.upload. Uncomment to run it.\n",
    "\n",
    "# tags_vocab.dump(TAGS_VOCAB)\n",
    "# model.emb.shape.dump(MODEL_SHAPE)\n",
    "# model.encoder.dump(MODEL_ENCODER)\n",
    "# model.morph.dump(MODEL_MORPH)\n",
    "\n",
    "# s3.upload(TAGS_VOCAB, S3_TAGS_VOCAB)\n",
    "# s3.upload(MODEL_SHAPE, S3_MODEL_SHAPE)\n",
    "# s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)\n",
    "# s3.upload(MODEL_MORPH, S3_MODEL_MORPH)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
